{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 34、Swin Transformer论文精讲及其PyTorch逐行实现"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Optional\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch import Tensor\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1、如何从图像获取embedding\n",
    "###  使用 unfold 分块\n",
    "\n",
    "- 基于pytorch `unfold`的API来将图片进行分块，也就是模仿卷积的思路，设置`kernel_size=stride=patch_size`，得到分块后的图片\n",
    "- 得到格式为`[bs, num_patch, patch_depth]`的张量\n",
    "- 将张量与形状为`[patch_depth, model_dim_C]`的权重矩阵进行乘法操作，即可得到形状为`[bs, num_patch, model_dim_C]`的patch embedding\n",
    "\n",
    "###  使用卷积\n",
    "\n",
    "- `patch_depth`是等于`input_channel*patch_size*patch_size`\n",
    "- `model_din_C`相当于二维卷积的输出通道数目\n",
    "- 将形状为`[patch_depth, model_dim_C]`的权重矩阵转换为`[model_dim_C, input_channel, patch_size, patch_size]`的卷积核\n",
    "- 调用PyTorch的`conv2d` API得到卷积的输出张量,形状为`[bs, output_channel, height, width]`\n",
    "- 转换为`[bs, num_patch, model_dim_C]`的格式,即为`patch embedding`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 分别使用 unfold 和卷积操作实现 image->embedding 操作\n",
    "def image2emb_naive(img: Tensor, weight, patch_size=16):\n",
    "    assert img.dim() == 4\n",
    "    batch_size, channel, img_width, img_height = img.shape\n",
    "    # shape: batch_size, patched, patch_num\n",
    "    region = F.unfold(img, patch_size, stride=patch_size)\n",
    "    region.transpose_(-1, -2)\n",
    "    print(region.shape)\n",
    "    return region @ weight\n",
    "\n",
    "\n",
    "def image2emb_conv(img: Tensor, kernel, patch_size=16):\n",
    "    \"\"\"基于二维卷积实现patch embedding,embedding 维度就是卷积通道数\"\"\"\n",
    "    conv = F.conv2d(img, kernel, stride=patch_size)\n",
    "    bs, oc, ow, oh = conv.shape\n",
    "    patch_emb = conv.reshape((bs, oc, ow*oh)).transpose(-1, -2) # (bs, ow*oh, oc)\n",
    "    return patch_emb"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "## 2、如何构建MHSA并计算其复杂度？\n",
    "\n",
    "- 基于输入x进行三个映射分别得到q,k,v\n",
    "    -   此步复杂度为$3LC^2$，其中L为序列长度，C为特征大小\n",
    "-   将q,k,v拆分成多头的形式，注意这里的多头各自计算不影响，所以可以与bs维度进行统一看待\n",
    "-   计算$qk^T$，并考虑可能的掩码，即让无效的两两位置之间的能量为负无穷，掩码是在shift window MHSA中会需要，而在window MHSA中暂不需要\n",
    "    -   此步复杂度为$L^2C$\n",
    "-   计算概率值与v的乘积\n",
    "    -   此步复杂度为$L^2C$\n",
    "-   对输出进行再次映射\n",
    "    -   此步复杂度为$LC^2$\n",
    "-   总体复杂度为$4LC^2 + 2L^2C$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dl",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
